import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import dataclasses
import matplotlib.pyplot as plt
import seaborn as sns
import genjax
from genjax import trace, slash, TFPUniform, Normal
sns.set_theme(style="white")
# Pretty printing.
console = genjax.pretty(width=80)
# Reproducibility.
key = jax.random.PRNGKey(314159)
Differentiable programming with GenJAX
This notebook introduces TrainCombinator, a generative function combinator which exposes interfaces that make it easy to train learnable parameters with custom optimizers (e.g. from optax) using gradients with respect to the log joint density of the model. To illustrate the interfaces, we showcase a MAP optimization implementation, as well as a Metropolis-adjusted Langevin algorithm (MALA) implementation. We also show how TrainCombinator can be used to perform maximum likelihood optimization, as well as variational family optimization.
The generative function interface exposes functionality which allows usage of generative functions for differentiable programming. These interfaces are designed to work seamlessly with jax.grad - allowing (even higher-order) gradient computations which are useful for inference algorithms that require gradients and gradient estimators. In this notebook, we’ll describe some of these interfaces, as well as their (current, but not forever) limitations. We’ll walk through an implementation of MAP estimation, as well as the Metropolis-adjusted Langevin algorithm (MALA), using these interfaces.
Gradient interfaces
Because JAX features best-in-class support for higher-order AD, GenJAX exposes interfaces that compose natively with JAX’s interfaces for gradients. The primary interface method which provides jax.grad-compatible functions from generative functions is an interface called unzip.
unzip allows a user to provide a key, and a fixed choice map - and it returns a new key and two closures:
- The first closure is a “score” closure which accepts a choice map as the first argument, and arguments which match the non-PRNGKey signature types of the generative function. The score closure returns the exact joint score of the generative function. It computes the exact joint score using an interface called assess.1
- The second closure is a “retval” closure which accepts a choice map as the first argument, and arguments which match the non-PRNGKey signature types of the generative function. The retval closure executes the generative function constrained using the union of the fixed choice map and the user-provided choice map, and returns the return value of the execution. Here, the return value is also provided by invoking the assess interface.
1 Caveat: assess is not required to return the exact joint score, only an estimate. However, if jax.grad is applied to such estimates, the result is not a correct gradient estimator. See the important callout below!
So really, unzip is syntactic sugar over another interface called assess.
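To make the shape of these closures concrete, here is a plain-Python sketch of the unzip pattern for a hypothetical model with a single traced choice "x" ~ Normal(mu, 1.0). The names toy_unzip and normal_logpdf are illustrative, not part of the GenJAX API:

```python
import math

def normal_logpdf(v, mu, sigma):
    # Log density of Normal(mu, sigma) evaluated at v.
    return (
        -0.5 * ((v - mu) / sigma) ** 2
        - math.log(sigma)
        - 0.5 * math.log(2.0 * math.pi)
    )

def toy_unzip(fixed_chm):
    # Returns (score, retval) closures, each taking a choice map
    # plus the model's non-PRNGKey arguments (here, just `mu`).
    def score(chm, mu):
        choices = {**fixed_chm, **chm}  # union of fixed and provided choices
        return normal_logpdf(choices["x"], mu, 1.0)

    def retval(chm, mu):
        choices = {**fixed_chm, **chm}
        return choices["x"]  # this toy model returns its sampled choice

    return score, retval

score, retval = toy_unzip({})
```

Because score is an ordinary function of the choice values, automatic differentiation (e.g. jax.grad) can differentiate it with respect to those values - which is exactly the property these interfaces preserve.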
assess for exact density evaluation
assess is a generative function interface method which computes log joint density estimates from generative functions. assess requires that a user provide a choice map which completely fills all choices encountered during execution. Otherwise, it errors.2
2 And these errors are thrown at JAX trace time, so you’ll get an exception before runtime.
If a generative function also draws from untraced randomness - assess computes an estimate whose expectation over the distribution of untraced randomness gives the correct log joint density.
When used on generative functions which include untraced randomness, naively applying jax.grad to the closures returned by the interfaces described in this notebook does not compute gradient estimators which are unbiased with respect to the true gradients.
Short: don’t use these with untraced randomness. We’re working on alternatives.
A running example
Let’s consider the following model, which we’ll cover in different variations.
# If you don't specify broadcast `in_axes`, you
# should specify number of IID samples via `repeats`.
@genjax.gen(genjax.Map, repeats=100)
def sample_x(x_mu, var):
    position = trace("pos", Normal)(x_mu, var)
    return position
@genjax.gen(genjax.Map, in_axes=(0, None, None))
def sample_y(x, a, b):
    position = trace("pos", Normal)(a * x + b, 1.0)
    return position
@genjax.gen
def model():
    x_mu = trace("x_mu", TFPUniform)(-3.0, 3.0)
    a = trace("a", TFPUniform)(-4.0, 4.0)
    b = trace("b", TFPUniform)(-3.0, 3.0)
    x = trace("x", sample_x)(x_mu, 1.0)
    y = trace("y", sample_y)(x, a, b)
    return y
And, most importantly, visualizations.
def viz(ax, x, y, **kwargs):
    sns.scatterplot(x=x, y=y, ax=ax, **kwargs)
f, axes = plt.subplots(3, 3, figsize=(9, 9), sharex=True, sharey=True, dpi=280)
jitted = jax.jit(model.simulate)
trs = []
for ax in axes.flatten():
    key, tr = jitted(key, ())
    x = tr["x", "pos"]
    y = tr["y", "pos"]
    trs.append(tr)
    viz(ax, x, y, marker=".")
plt.show()
A nice diffuse prior over points.
MAP estimation
When it comes to looking at the interfaces, a good first step is gradient-based maximum a posteriori probability (MAP) estimation. Let’s write this using the lowest level interface unzip first:
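Before looking at the interfaces, it helps to see the bare math: a MAP step nudges the selected choices along the gradient of the log joint density. A minimal plain-Python sketch for a toy posterior x ~ Normal(mu0, 1), y ~ Normal(x, 1) with observed y (the names here are illustrative, not the GenJAX API):

```python
def map_step(x, y, mu0, tau):
    # d/dx [ log N(x; mu0, 1) + log N(y; x, 1) ] = (mu0 - x) + (y - x)
    grad = (mu0 - x) + (y - x)
    return x + tau * grad

x = 0.0
for _ in range(1000):
    x = map_step(x, y=2.0, mu0=0.0, tau=0.1)
# x converges to the posterior mode (mu0 + y) / 2
```

The GenJAX version below does the same thing, but computes the gradient through the generative function interfaces rather than by hand.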
Now, often we may have a trace in hand, and we just want the first-order gradient with respect to certain random choices (specified by a genjax.Selection). This is a relatively common occurrence - so there’s a higher-level API choice_grad which gives us exactly this.3 Here’s MapUpdate using choice_grad.
3 It’s not compositional with jax.grad - but if we need that power, we can just drop back down to use unzip.
@dataclasses.dataclass
class MapUpdate(genjax.Pytree):
    selection: genjax.Selection
    tau: genjax.FloatArray

    def flatten(self):
        return (self.tau,), (self.selection,)

    def apply(self, key, trace):
        args = trace.get_args()
        gen_fn = trace.get_gen_fn()
        key, forward_gradient_trie = gen_fn.choice_grad(
            key, trace, self.selection
        )
        forward_values, _ = self.selection.filter(trace)
        forward_values = forward_values.strip()
        forward_values = jtu.tree_map(
            lambda v1, v2: v1 + self.tau * v2,
            forward_values,
            forward_gradient_trie,
        )
        argdiffs = tuple(map(genjax.Diff.no_change, args))
        key, (_, _, new_trace, _) = gen_fn.update(
            key, trace, forward_values, argdiffs
        )
        return key, (new_trace, True)

    def __call__(self, key, trace):
        return self.apply(key, trace)
map_update = MapUpdate
Simple, concise - works with any generative function whose choices (as specified by MapUpdate.selection) support gradients of the joint logpdf.
From the Pytree interface, any instance of MapUpdate has a static selection, and a tau (which determines the gradient step size) which can be dynamic.4
4 If this is your first time seeing the Pytree interface, note that it’s defined by the flatten interface - which allows us to specify runtime vs. trace time data in Pytree structures.
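A stripped-down illustration of the flatten convention (plain Python, with a hypothetical ToyUpdate class - not the GenJAX Pytree machinery):

```python
class ToyUpdate:
    def __init__(self, selection, tau):
        self.selection = selection  # static (trace-time) data
        self.tau = tau              # dynamic (runtime) leaf

    def flatten(self):
        # Dynamic leaves first, then static auxiliary data.
        return (self.tau,), (self.selection,)

u = ToyUpdate("x_mu", 1e-4)
dynamic, static = u.flatten()
```

The dynamic leaves are what JAX traces and transforms; the static data is carried alongside unchanged.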
Because MapUpdate is a Pytree, in inference code, we’d just construct MapUpdate before calling it - and we can do this on either side of the JAX API boundary.5
5 E.g. outside of a jax.jit transform, inside - it’s all okay.
update = map_update(genjax.select(["x_mu", "a", "b"]), 1e-4)
update
MapUpdate
├── selection
│   └── BuiltinSelection
│       └── trie
│           └── Trie
│               ├── :x_mu
│               │   └── AllSelection
│               ├── :a
│               │   └── AllSelection
│               └── :b
│                   └── AllSelection
└── tau
    └── (lit) 0.0001
Let’s take a sampled piece of data, extract the ("x", "pos") and ("y", "pos") addresses, and then use MAP optimization to estimate the mode of the posterior.
tr = trs[2]
selection = genjax.select(["x", "y"])
chm, _ = selection.filter(tr.strip())
chm
BuiltinChoiceMap
└── trie
    └── Trie
        ├── :x
        │   └── VectorChoiceMap
        │       ├── indices
        │       │   └── i32[100]
        │       └── inner
        │           └── BuiltinChoiceMap
        │               └── trie
        │                   └── Trie
        │                       └── :pos
        │                           └── ValueChoiceMap
        │                               └── value
        │                                   └── f32[100]
        └── :y
            └── VectorChoiceMap
                ├── indices
                │   └── i32[100]
                └── inner
                    └── BuiltinChoiceMap
                        └── trie
                            └── Trie
                                └── :pos
                                    └── ValueChoiceMap
                                        └── value
                                            └── f32[100]
x = chm["x", "pos"]
y = chm["y", "pos"]
fig_data, ax_data = plt.subplots(figsize=(3, 3), dpi=140)
viz(ax_data, x, y, marker=".")
If we apply MapUpdate, we take a single optimization step:
key, (_, tr) = jax.jit(model.importance)(key, chm, ())
key, (tr, _) = update(key, tr)
We can use scan to apply MapUpdate repeatedly.
def chain(key, tr):
    def _inner(carry, _):
        key, tr = carry
        key, (tr, _) = update(key, tr)
        return (key, tr), ()

    (key, tr), _ = jax.lax.scan(_inner, (key, tr), None, length=2000)
    return key, tr

jitted = jax.jit(chain)
key, tr = jitted(key, tr)
Now, we can plot the polynomial described by "a" and "b", with evaluation points generated around the estimated "x_mu".
def polynomial_at_x(x, coefficients):
    basis_values = jnp.array([1.0, x])
    polynomial_value = jnp.sum(coefficients * basis_values)
    return polynomial_value

jitted = jax.jit(jax.vmap(polynomial_at_x, in_axes=(0, None)))

def plot_polynomial_values(ax, x, coefficients, **kwargs):
    v = jitted(x, coefficients)
    ax.scatter(x, v, **kwargs)
a = tr["a"]
b = tr["b"]
x_mu = tr["x_mu"]
key, sub_key = jax.random.split(key)
evaluation_points = x_mu + jax.random.normal(sub_key, shape=(1000,))
coefficients = jnp.array([b, a])
plot_polynomial_values(
ax_data,
evaluation_points,
coefficients,
marker=".",
color="gold",
alpha=0.05,
)
fig_data
Exposing learnable modules with TrainCombinator
For learning and variational inference, learnable parameters of model families are an important feature. In GenJAX, we expose a lightweight wrapper around generative functions which accept learnable parameters as their last argument - this wrapper is called TrainCombinator.
params = {"x_mu": 0.0, "a": 0.3, "b": 0.4}
@genjax.gen(genjax.TrainCombinator, params=params)
def model(params):
    x_mu = params["x_mu"]
    a = params["a"]
    b = params["b"]
    x = genjax.trace("x", genjax.Normal)(x_mu, 1.0)
    return genjax.trace("y", genjax.Normal)(a * x + b, 1.0)

model
TrainCombinator
├── inner
│   └── BuiltinGenerativeFunction
│       └── source
│           └── <function model>
└── params
    └── dict
        ├── x_mu
        │   └── (lit) 0.0
        ├── a
        │   └── (lit) 0.3
        └── b
            └── (lit) 0.4
TrainCombinator is a module-like abstraction which closes over the parameter store passed in and initialized by the constructor. When we call TrainCombinator, we don’t need to provide the params argument - it’s handled automatically.
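The wrapper pattern itself is simple; a plain-Python analogue (ToyTrain and affine are illustrative names, not GenJAX API) looks like:

```python
class ToyTrain:
    # Closes over a parameter store so callers never pass `params`.
    def __init__(self, fn, params):
        self.fn = fn
        self.params = params

    def __call__(self, *args):
        # The stored params are appended as the final argument.
        return self.fn(*args, self.params)

def affine(x, params):
    return params["a"] * x + params["b"]

toy_model = ToyTrain(affine, {"a": 2.0, "b": 1.0})
```

Keeping the parameter store on the wrapper is what lets training interfaces swap in candidate parameter values without changing the model's call signature.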
key, tr = model.simulate(key, ())
tr
BuiltinTrace
├── gen_fn
│   └── BuiltinGenerativeFunction
│       └── source
│           └── <function model>
├── args
│   └── tuple
│       └── dict
│           ├── x_mu
│           │   └── (lit) 0.0
│           ├── a
│           │   └── (lit) 0.3
│           └── b
│               └── (lit) 0.4
├── retval
│   └── f32[]
├── choices
│   └── Trie
│       ├── :x
│       │   └── DistributionTrace
│       │       ├── gen_fn
│       │       │   └── _Normal
│       │       ├── args
│       │       │   └── tuple
│       │       │       ├── (lit) 0.0
│       │       │       └── (lit) 1.0
│       │       ├── value
│       │       │   └── f32[]
│       │       └── score
│       │           └── f32[]
│       └── :y
│           └── DistributionTrace
│               ├── gen_fn
│               │   └── _Normal
│               ├── args
│               │   └── tuple
│               │       ├── f32[]
│               │       └── (lit) 1.0
│               ├── value
│               │   └── f32[]
│               └── score
│                   └── f32[]
├── cache
│   └── Trie
└── score
    └── f32[]
TrainCombinator exposes a convenient interface to a specialized scoring function which accepts evaluation points for params, and returns the model logpdf.6
6 Note: because this interface method returns a function, we cannot JIT it directly. However, if we use it to produce a closure and then use that closure inside of code which we JIT, it’s fine - as long as the function object doesn’t try to escape across the JAX API boundary. Producing this closure is also speedy!
key, logpdf = model.score_params(key, tr, params)
logpdf
Array(-3.007322, dtype=float32)
This interface supports batching out of the box.
key, *sub_keys = jax.random.split(key, 100 + 1)
sub_keys = jnp.array(sub_keys)
sub_keys, tr = jax.vmap(model.simulate)(sub_keys, ())
_, logpdf = jax.vmap(model.score_params, in_axes=(0, 0, None))(
sub_keys, tr, params
)
logpdf
Array([-1.9976285, -3.5780087, -1.9807117, -3.2873745, -3.7859497, -2.6952581, -2.0800304, -3.6221423, -3.8627763, -2.0281572, -2.16537, -2.1819763, -2.4236362, -2.2720413, -4.0935936, -2.8679984, -2.2018921, -2.8351343, -3.131372, -5.0481167, -2.612084, -4.7577085, -3.059651, -3.8924067, -2.0701456, -2.5086784, -2.8568597, -2.5822752, -2.3407137, -3.35099, -3.9029412, -1.8635211, -1.9603972, -2.4779296, -1.9830978, -3.5628843, -3.360316, -2.33081, -3.569415, -2.6337898, -2.9075725, -2.0472188, -2.127375, -3.4633002, -2.5945296, -2.6425881, -2.155263, -2.1422853, -2.2873197, -4.2150993, -1.8974092, -1.9259053, -2.4462447, -1.8541551, -2.7611914, -1.9701724, -3.810885, -2.9374416, -2.1419594, -2.700367, -4.6069875, -3.0089617, -1.8390343, -2.0409722, -2.1649864, -2.3391469, -2.3591323, -3.250299, -2.9457328, -4.334495, -3.0190947, -4.1563606, -2.7005963, -3.1220098, -1.9379404, -2.7643821, -2.4005075, -4.619805, -2.2810106, -3.459144, -2.371623, -2.2807765, -3.37915, -2.0856595, -1.8977728, -1.8606322, -2.6173737, -2.373567, -2.30009, -1.8428204, -2.0290344, -2.6710665, -3.5239663, -2.3498297, -2.5525582, -2.0731103, -2.8109097, -2.3445606, -3.8245656, -2.5873551], dtype=float32)
We make extensive use of batch evaluation in variational inference. For now, let’s consider maximum likelihood learning and see the other TrainCombinator interfaces.
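As a warm-up for what score_params enables, here is the core of maximum likelihood learning in plain numpy: gradient ascent on the data log likelihood of a Normal(mu, 1) model, which converges to the sample mean (all names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(1.5, 1.0, size=200)

mu, lr = 0.0, 0.1
for _ in range(500):
    # Average gradient of log N(data_i; mu, 1) w.r.t. mu.
    grad = np.mean(data - mu)
    mu += lr * grad
# mu converges to data.mean(), the maximum likelihood estimate
```

With TrainCombinator, the handwritten gradient is replaced by differentiating the scoring function with jax.grad, and the update step by an optimizer such as one from optax.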
Automatic differentiation variational inference
In this section, we’ll show how we can use the gradient interfaces to implement automatic differentiation variational inference (ADVI).
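To preview the underlying math before using the GenJAX interfaces: with a Normal(mu, 1) variational family and the reparameterization trick, the ELBO gradient with respect to mu reduces to an expectation we can estimate by sampling. A numpy sketch for a Normal(2, 1) target (illustrative only, not the GenJAX implementation):

```python
import numpy as np

rng = np.random.default_rng(1)
mu, lr = 0.0, 0.05  # variational parameter of q(x) = N(mu, 1)
for _ in range(2000):
    eps = rng.normal(size=64)
    x = mu + eps  # reparameterized samples from q
    # With sigma fixed at 1, the entropy term is constant in mu, so the
    # ELBO gradient is E[d/dmu log p(x)] with target p(x) = N(2, 1).
    grad = np.mean(2.0 - x)
    mu += lr * grad
# mu converges (stochastically) toward the target mean 2.0
```

The GenJAX version replaces the hand-derived gradient with automatic differentiation through the batched scoring interfaces shown above.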